#!/usr/bin/python3 -u
import os, os.path, sys, stat, signal, errno, argparse, time, json, re

# Seconds a single child process is given to exit after being signalled
# before it is sent SIGKILL (default time limit of stop_child_process()).
KILL_PROCESS_TIMEOUT = 5
# Seconds all remaining processes are given to exit at shutdown before they
# are sent SIGKILL (passed to kill_all_processes()).
KILL_ALL_PROCESSES_TIMEOUT = 5

# Verbosity thresholds: a message of a given severity is written when
# log_level is greater than or equal to that severity's threshold.
LOG_LEVEL_ERROR = 1
# NOTE(review): same value as LOG_LEVEL_ERROR, so warnings and errors are
# always shown or hidden together -- confirm this is intended.
LOG_LEVEL_WARN  = 1
LOG_LEVEL_INFO  = 2
LOG_LEVEL_DEBUG = 3

# Active verbosity. Stays None until the command line has been parsed.
log_level = None

# Maps PID -> wait status for children that waitpid_reap_other_children()
# reaped while it was waiting for a different PID.
terminated_child_processes = {}

class AlarmException(Exception):
	"""Raised from the SIGALRM handler to abort a wait that exceeded its time limit."""

def error(message):
	"""Write *message* to stderr, prefixed with "*** ", if error logging is enabled."""
	if log_level < LOG_LEVEL_ERROR:
		return
	sys.stderr.write("*** %s\n" % message)

def warn(message):
	"""Write *message* to stderr, prefixed with "*** ", if warning logging is enabled."""
	if log_level < LOG_LEVEL_WARN:
		return
	sys.stderr.write("*** %s\n" % message)

def info(message):
	"""Write *message* to stderr, prefixed with "*** ", if info logging is enabled."""
	if log_level < LOG_LEVEL_INFO:
		return
	sys.stderr.write("*** %s\n" % message)

def debug(message):
	"""Write *message* to stderr, prefixed with "*** ", if debug logging is enabled."""
	if log_level < LOG_LEVEL_DEBUG:
		return
	sys.stderr.write("*** %s\n" % message)

def ignore_signals_and_raise_keyboard_interrupt(signame):
	"""Ignore further SIGTERM/SIGINT, then raise KeyboardInterrupt(signame).

	Ignoring both signals first means a repeated signal cannot interrupt
	whatever handles the KeyboardInterrupt raised here.
	"""
	for signo in (signal.SIGTERM, signal.SIGINT):
		signal.signal(signo, signal.SIG_IGN)
	raise KeyboardInterrupt(signame)

def raise_alarm_exception():
	"""Raise AlarmException('Alarm'); used as the body of the SIGALRM handler."""
	raise AlarmException('Alarm')

def listdir(path):
	"""Return the sorted entry names of directory *path*.

	Returns an empty list when *path* cannot be stat'ed or is not a
	directory.
	"""
	try:
		mode = os.stat(path).st_mode
	except OSError:
		return []
	if not stat.S_ISDIR(mode):
		return []
	entries = os.listdir(path)
	entries.sort()
	return entries

def is_exe(path):
	"""Return True if *path* is a regular file that the process may execute."""
	try:
		if not os.path.isfile(path):
			return False
		return os.access(path, os.X_OK)
	except OSError:
		return False

def import_envvars(clear_existing_environment = True, override_existing_environment = True):
	"""Load environment variables from the files in /etc/container_environment.

	Each file name is a variable name and the file's content is its value,
	with a single trailing newline removed.

	clear_existing_environment -- when true, os.environ is emptied before
		the loaded variables are applied.
	override_existing_environment -- when false, variables already present
		in os.environ keep their current value.
	"""
	new_env = {}
	for envfile in listdir("/etc/container_environment"):
		name = os.path.basename(envfile)
		with open("/etc/container_environment/" + envfile, "r") as f:
			# Text files often end with a trailing newline, which we
			# don't want to include in the env variable value. See
			# https://github.com/phusion/baseimage-docker/pull/49
			#
			# The pattern is a raw string: '\Z' is not a valid string
			# escape, so a plain literal triggers an invalid-escape
			# warning on modern Python.
			value = re.sub(r'\n\Z', '', f.read())
		new_env[name] = value
	if clear_existing_environment:
		os.environ.clear()
	for name, value in new_env.items():
		if override_existing_environment or name not in os.environ:
			os.environ[name] = value

def export_envvars(to_dir = True):
	"""Dump the current environment for other processes to consume.

	Writes /etc/container_environment.sh (shell `export` statements) and
	/etc/container_environment.json (all of os.environ). When *to_dir* is
	true, each exported variable is also written to its own file under
	/etc/container_environment/. HOME, USER, GROUP, UID, GID and SHELL are
	left out of the directory and the shell dump, but not the JSON dump.
	"""
	skipped = ('HOME', 'USER', 'GROUP', 'UID', 'GID', 'SHELL')
	export_lines = []
	for key, val in os.environ.items():
		if key in skipped:
			continue
		if to_dir:
			with open("/etc/container_environment/" + key, "w") as env_file:
				env_file.write(val)
		export_lines.append("export " + shquote(key) + "=" + shquote(val) + "\n")
	with open("/etc/container_environment.sh", "w") as sh_file:
		sh_file.write("".join(export_lines))
	with open("/etc/container_environment.json", "w") as json_file:
		json_file.write(json.dumps(dict(os.environ)))

# Finds the first character that is not safe to leave unquoted in a shell word.
_find_unsafe = re.compile(r'[^\w@%+=:,./-]').search

def shquote(s):
	"""Return a shell-escaped version of the string *s*.

	Strings made only of safe characters are returned unchanged; an empty
	value becomes ''. Everything else is wrapped in single quotes.
	"""
	if s and _find_unsafe(s) is None:
		return s

	# Wrap in single quotes; each embedded single quote becomes '"'"'
	# (close quote, double-quoted quote, reopen quote), so the string
	# $'b is quoted as '$'"'"'b'.
	pieces = s.split("'") if s else [""]
	return "'" + "'\"'\"'".join(pieces) + "'"

# Waits for the child process with the given PID, while at the same time
# reaping any other child processes that have exited (e.g. adopted child
# processes that have terminated).
def waitpid_reap_other_children(pid):
	"""Wait for child *pid* and return its wait status.

	Other children that exit in the meantime are reaped too; their
	statuses are stored in terminated_child_processes so that a later call
	for one of those PIDs can return the status without waiting again.

	Returns the status as reported by os.waitpid(), or None when there are
	no child processes left to wait for (ECHILD/ESRCH). Any other OSError
	is re-raised.
	"""
	global terminated_child_processes

	# A previous call, made for a different PID, may already have reaped
	# this process. Test for membership rather than truthiness: a clean
	# exit has wait status 0, which is falsy but still a valid cached
	# result that must be returned instead of waiting again.
	if pid in terminated_child_processes:
		return terminated_child_processes.pop(pid)

	while True:
		try:
			this_pid, status = os.waitpid(-1, 0)
		except OSError as e:
			if e.errno in (errno.ECHILD, errno.ESRCH):
				return None
			raise
		if this_pid == pid:
			return status
		# Not the process we are waiting for; save its status for later.
		terminated_child_processes[this_pid] = status

def stop_child_process(name, pid, signo = signal.SIGTERM, time_limit = KILL_PROCESS_TIMEOUT):
	"""Signal child *pid* and wait for it to exit, killing it on timeout.

	name -- label used in log messages only.
	pid -- PID of the child process to stop.
	signo -- signal sent first (SIGTERM by default).
	time_limit -- seconds to wait before SIGKILL is sent.

	Errors from os.kill() and from waiting are ignored. Relies on the
	SIGALRM handler raising AlarmException when the time limit expires.
	"""
	info("Shutting down %s (PID %d)..." % (name, pid))
	try:
		os.kill(pid, signo)
	except OSError:
		# Best effort: a failed kill is not treated as an error.
		pass
	# Arm the timeout; AlarmException interrupts the wait below.
	signal.alarm(time_limit)
	try:
		try:
			waitpid_reap_other_children(pid)
		except OSError:
			pass
	except AlarmException:
		warn("%s (PID %d) did not shut down in time. Forcing it to exit." % (name, pid))
		try:
			os.kill(pid, signal.SIGKILL)
		except OSError:
			pass
		# Reap the killed process; no alarm is armed for this wait.
		try:
			waitpid_reap_other_children(pid)
		except OSError:
			pass
	finally:
		# Always cancel a pending alarm so it cannot fire later.
		signal.alarm(0)

def run_command_killable(*argv):
	"""Run the command *argv* and wait for it, stopping it if the wait is interrupted.

	If any exception (including KeyboardInterrupt) arrives while waiting,
	the child is stopped with stop_child_process() and the exception is
	re-raised. If the command finishes with a non-zero or unknown status,
	an error is logged and the script exits with status 1.
	"""
	program = argv[0]
	child_pid = os.spawnvp(os.P_NOWAIT, program, argv)
	try:
		status = waitpid_reap_other_children(child_pid)
	except BaseException:
		warn("An error occurred. Aborting.")
		stop_child_process(program, child_pid)
		raise
	if status == 0:
		return
	if status is None:
		error("%s exited with unknown status\n" % program)
	else:
		error("%s failed with status %d\n" % (program, os.WEXITSTATUS(status)))
	sys.exit(1)

def run_command_killable_and_import_envvars(*argv):
	"""Run *argv* with run_command_killable(), then refresh the environment.

	After the command finishes, the environment is reloaded from
	/etc/container_environment (replacing the current one) and the shell
	and JSON dumps are rewritten without rewriting the directory itself.
	"""
	run_command_killable(*argv)
	import_envvars()
	export_envvars(False)

def kill_all_processes(time_limit):
	"""Send SIGTERM to every process and wait up to *time_limit* seconds.

	Waits until no child processes remain. If the alarm fires first,
	SIGKILL is sent to every process instead. Errors from os.kill() are
	ignored.
	"""
	info("Killing all processes...")
	try:
		os.kill(-1, signal.SIGTERM)
	except OSError:
		pass
	signal.alarm(time_limit)
	try:
		# Keep reaping until waitpid reports that no children are left.
		while True:
			try:
				os.waitpid(-1, 0)
			except OSError as e:
				if e.errno != errno.ECHILD:
					raise
				break
	except AlarmException:
		warn("Not all processes have exited in time. Forcing them to exit.")
		try:
			os.kill(-1, signal.SIGKILL)
		except OSError:
			pass
	finally:
		signal.alarm(0)

def run_startup_files():
	"""Run the executables in /etc/my_init.d in sorted order, then /etc/rc.local.

	Non-executable entries are skipped. The environment is re-imported
	after each script that runs.
	"""
	# Run /etc/my_init.d/*
	for entry in listdir("/etc/my_init.d"):
		script = "/etc/my_init.d/" + entry
		if not is_exe(script):
			continue
		info("Running %s..." % script)
		run_command_killable_and_import_envvars(script)

	# Run /etc/rc.local.
	if not is_exe("/etc/rc.local"):
		return
	info("Running /etc/rc.local...")
	run_command_killable_and_import_envvars("/etc/rc.local")

def start_runit():
	"""Start `runsvdir -P /etc/service` without waiting for it and return its PID."""
	info("Booting runit daemon...")
	runsvdir = "/usr/bin/runsvdir"
	runit_pid = os.spawnl(os.P_NOWAIT, runsvdir, runsvdir, "-P", "/etc/service")
	info("Runit started as PID %d" % runit_pid)
	return runit_pid

def wait_for_runit_or_interrupt(pid):
	"""Wait for process *pid*, treating KeyboardInterrupt as a normal outcome.

	Returns (True, status) when the process exited, where status is what
	waitpid_reap_other_children() returned, or (False, None) when the wait
	was cut short by a KeyboardInterrupt.
	"""
	try:
		return (True, waitpid_reap_other_children(pid))
	except KeyboardInterrupt:
		return (False, None)

def shutdown_runit_services():
	"""Run `sv down` on every entry under /etc/service.

	The exit status of the command is not checked.
	"""
	debug("Begin shutting down runit services...")
	os.system("/usr/bin/sv down /etc/service/*")

def wait_for_runit_services():
	"""Block until `sv status` reports no service under /etc/service as running.

	Polls every 0.1 seconds. The pipeline exits with status 0 for as long
	as grep still finds a line starting with "run:".
	"""
	debug("Waiting for runit services to exit...")
	while os.system("/usr/bin/sv status /etc/service/* | grep -q '^run:'") == 0:
		time.sleep(0.1)

def install_insecure_key():
	"""Run /usr/sbin/enable_insecure_key via run_command_killable().

	As with any run_command_killable() call, the script exits with status
	1 if the command fails.
	"""
	info("Installing insecure SSH key for user root")
	run_command_killable("/usr/sbin/enable_insecure_key")

def main(args):
	"""Run the init sequence: environment, startup files, runit, main command.

	args -- the parsed command-line namespace (enable_insecure_key,
		skip_startup_files, skip_runit, main_command).

	Without a main command, waits for runit; otherwise runs the main
	command and waits for it. Ends with sys.exit(exit_status), where
	exit_status stays None if the wait for runit was interrupted. Unless
	runit is skipped, the runit services and daemon are shut down on the
	way out, whatever the outcome.
	"""
	# Merge variables from /etc/container_environment without clearing or
	# overriding what is already set, then dump the resulting environment.
	import_envvars(False, False)
	export_envvars()

	if args.enable_insecure_key:
		install_insecure_key()

	if not args.skip_startup_files:
		run_startup_files()
	
	runit_exited = False
	exit_code = None

	if not args.skip_runit:
		runit_pid = start_runit()
	try:
		exit_status = None
		if len(args.main_command) == 0:
			# No main command: runit is what we wait for. A
			# KeyboardInterrupt during the wait is reported as
			# runit_exited == False rather than raised.
			runit_exited, exit_code = wait_for_runit_or_interrupt(runit_pid)
			if runit_exited:
				if exit_code is None:
					info("Runit exited with unknown status")
					exit_status = 1
				else:
					exit_status = os.WEXITSTATUS(exit_code)
					info("Runit exited with status %d" % exit_status)
		else:
			info("Running %s..." % " ".join(args.main_command))
			pid = os.spawnvp(os.P_NOWAIT, args.main_command[0], args.main_command)
			try:
				exit_code = waitpid_reap_other_children(pid)
				if exit_code is None:
					info("%s exited with unknown status." % args.main_command[0])
					exit_status = 1
				else:
					exit_status = os.WEXITSTATUS(exit_code)
					info("%s exited with status %d." % (args.main_command[0], exit_status))
			except KeyboardInterrupt:
				# Interrupted by a signal: stop the main command quietly
				# and let the caller handle the interrupt.
				stop_child_process(args.main_command[0], pid)
				raise
			except BaseException as s:
				warn("An error occurred. Aborting.")
				stop_child_process(args.main_command[0], pid)
				raise
		sys.exit(exit_status)
	finally:
		# Runs for normal exit (SystemExit), interrupts and errors alike.
		if not args.skip_runit:
			shutdown_runit_services()
			if not runit_exited:
				stop_child_process("runit daemon", runit_pid)
			wait_for_runit_services()

# Parse options.
parser = argparse.ArgumentParser(description = 'Initialize the system.')
parser.add_argument(
	'main_command', metavar = 'MAIN_COMMAND', type = str, nargs = '*',
	help = 'The main command to run. (default: runit)')
# Switches that are off unless given.
parser.add_argument(
	'--enable-insecure-key', dest = 'enable_insecure_key', action = 'store_true',
	help = 'Install the insecure SSH key')
parser.add_argument(
	'--skip-startup-files', dest = 'skip_startup_files', action = 'store_true',
	help = 'Skip running /etc/my_init.d/* and /etc/rc.local')
parser.add_argument(
	'--skip-runit', dest = 'skip_runit', action = 'store_true',
	help = 'Do not run runit services')
# On by default; the flag turns it off.
parser.add_argument(
	'--no-kill-all-on-exit', dest = 'kill_all_on_exit', action = 'store_false',
	help = "Don't kill all processes on the system upon exiting")
parser.add_argument(
	'--quiet', dest = 'log_level', action = 'store_const',
	const = LOG_LEVEL_WARN, default = LOG_LEVEL_INFO,
	help = 'Only print warnings and errors')
args = parser.parse_args()
log_level = args.log_level

# Without runit there is nothing to wait for unless a main command is given.
if args.skip_runit and not args.main_command:
	error("When --skip-runit is given, you must also pass a main command.")
	sys.exit(1)

# Run main function.
# SIGTERM and SIGINT are turned into a KeyboardInterrupt (further SIGTERM and
# SIGINT signals are ignored from then on); SIGALRM raises AlarmException,
# which the timeout logic in stop_child_process() and kill_all_processes()
# relies on.
signal.signal(signal.SIGTERM, lambda signum, frame: ignore_signals_and_raise_keyboard_interrupt('SIGTERM'))
signal.signal(signal.SIGINT, lambda signum, frame: ignore_signals_and_raise_keyboard_interrupt('SIGINT'))
signal.signal(signal.SIGALRM, lambda signum, frame: raise_alarm_exception())
try:
	main(args)
except KeyboardInterrupt:
	warn("Init system aborted.")
	# Use sys.exit() rather than the exit() builtin: the latter is injected
	# by the site module for interactive use and is absent under `python -S`.
	sys.exit(2)
finally:
	# Runs on every exit path, including the SystemExit raised by main().
	if args.kill_all_on_exit:
		kill_all_processes(KILL_ALL_PROCESSES_TIMEOUT)
